import shutil
import tempfile
import pybedtools
from numpy import *

from rpy2 import robjects, rinterface
import rpy2.robjects.numpy2ri
robjects.numpy2ri.activate()
from rpy2.robjects.packages import importr
deseq = importr('DESeq2')
print("Using DESeq2 version %s" % deseq.__version__)


# Time points allowed in the "t<N>" term of HiSeq library names; library
# names are checked against this tuple when the conditions are assigned.
timepoints = (0, 1, 4, 12, 24, 96)

# Read the expression table. The first line is a header of the form
# "peak <library> <library> ..."; every following line holds a peak name
# followed by one integer count per library, separated by whitespace.
filename = "peaks.HiSeq_StartSeq.expression.txt"
print("Reading", filename)
counts = []
peak_names = []
# The context manager closes the file even if a malformed line raises below.
with open(filename) as handle:
    line = next(handle, None)
    if line is None:
        raise ValueError("%s is empty; expected a header line" % filename)
    words = line.split()
    # Explicit checks instead of assert, which is stripped under "python -O".
    if not words or words[0] != 'peak':
        raise ValueError("%s: header line must start with 'peak'" % filename)
    libraries = words[1:]
    for number, line in enumerate(handle, start=2):
        words = line.split()
        if not words:
            # Tolerate blank lines (for example at the end of the file).
            continue
        peak_name = words[0]
        row = array(words[1:], int)
        if len(row) != len(libraries):
            # A ragged row would otherwise only fail later, when the rows
            # are stacked into a single matrix.
            raise ValueError("%s, line %d: expected %d counts, found %d"
                             % (filename, number, len(libraries), len(row)))
        counts.append(row)
        peak_names.append(peak_name)

# Stack the per-peak rows into a (peaks x libraries) matrix.
counts = array(counts)
if counts.ndim != 2:
    # Happens when no peaks were read; fail here with a clear message rather
    # than with an obscure error further down.
    raise ValueError("Expected a 2-dimensional count matrix; "
                     "found %d dimension(s)" % counts.ndim)
# Summing over axis 0 (the peaks) gives the total count of each library.
totals = sum(counts, 0)

# Assign each library to a condition, which is simply its dataset name
# ("HiSeq" or "StartSeq"), after validating the library name. Explicit
# exceptions are used instead of assert so that the validation is not
# stripped under "python -O".
conditions = []
for (library, total) in zip(libraries, totals):
    terms = library.split("_")
    dataset = terms[0]
    if dataset == "HiSeq":
        # Expected form: HiSeq_t<timepoint>_r<replicate>
        if len(terms) != 3:
            raise ValueError("Library %s: expected HiSeq_t<N>_r<N>" % library)
        timepoint = terms[1]
        if not timepoint.startswith("t"):
            raise ValueError("Library %s: time point term must start with 't'"
                             % library)
        timepoint = int(timepoint[1:])
        if timepoint not in timepoints:
            raise ValueError("Library %s: unknown time point %d"
                             % (library, timepoint))
        replicate = terms[2]
        if replicate not in ("r1", "r2", "r3"):
            raise ValueError("Library %s: unknown replicate %s"
                             % (library, replicate))
    elif dataset == "StartSeq":
        # Expected form: StartSeq_<accession>; checking the length first
        # avoids an IndexError when the accession term is missing.
        if len(terms) < 2 or terms[1] not in ("SRR7071452", "SRR7071453"):
            raise ValueError("Library %s: unknown StartSeq accession" % library)
    else:
        raise ValueError("Unknown dataset %s in library %s"
                         % (dataset, library))
    conditions.append(dataset)
    print("Including library %s with total count %d" % (library, total))

# Look up the R functions that are called directly (not via the importr
# package object) further down.
estimateSizeFactors = robjects.r['estimateSizeFactors']
results = robjects.r['results']

# Sample table with a single column, "dataset", holding the condition of
# each library in the same order as the columns of the count matrix.
metadata = {'dataset': robjects.StrVector(conditions)}
dataframe = robjects.DataFrame(metadata)

# Full model (~ dataset) and reduced model (~ 1) for the likelihood ratio
# test requested from DESeq below.
design = robjects.Formula("~ dataset")
reduced = robjects.Formula("~ 1")

dds = deseq.DESeqDataSetFromMatrix(
    countData=counts,
    colData=dataframe,
    design=design,
)
dds = estimateSizeFactors(dds)
dds = deseq.DESeq(dds, fitType="glmGamPoi", test="LRT", reduced=reduced)
res = results(dds)

# The result columns are stored in the 'listData' slot; verify their number
# and order before picking columns by position.
output = res.do_slot('listData')
names = output.names
expected_names = ('baseMean', 'log2FoldChange', 'lfcSE',
                  'stat', 'pvalue', 'padj')
assert len(names) == 6
for position, expected_name in enumerate(expected_names):
    assert names[position] == expected_name

basemeans = array(output[0])
log2fcs = array(output[1])
# NOTE(review): column 5 is 'padj' (checked above), not 'pvalue' (column 4),
# although the values are stored as `pvalues` -- confirm that using the
# adjusted p-values here is intended.
pvalues = array(output[5])

# Replace NaN entries by neutral values: no fold change, and a p-value of 1.
undefined = isnan(log2fcs)
log2fcs[undefined] = 0
undefined = isnan(pvalues)
pvalues[undefined] = 1

# Write the results as a tab-delimited table: a header line followed by one
# line per peak with its base mean, log2 fold change and p-value column.
filename = "peaks.HiSeq_StartSeq.deseq.txt"
print("Writing", filename)
# zip() stops silently at the shortest input, which would drop or misalign
# peaks; make sure every column has exactly one value per peak.
if not (len(peak_names) == len(basemeans) == len(log2fcs) == len(pvalues)):
    raise ValueError("Number of peaks (%d) does not match the number of "
                     "results (%d, %d, %d)"
                     % (len(peak_names), len(basemeans),
                        len(log2fcs), len(pvalues)))
# The context manager flushes and closes the file even if a write fails.
with open(filename, 'w') as handle:
    handle.write("peak\tbasemean\tlog2fc\tpvalue\n")
    for peak_name, basemean, log2fc, pvalue in zip(peak_names, basemeans,
                                                   log2fcs, pvalues):
        handle.write("%s\t%g\t%g\t%g\n"
                     % (peak_name, basemean, log2fc, pvalue))
